1 /*
2   D bindings for CUDA.
3   Authors:    Prasun Anand
4   Copyright:  Copyright (c) 2017, Prasun Anand. All rights reserved.
5   License:    BSD 3-Clause License
6 */
7 module cuda_d.cublastXt;
8 
9 import cuda_d.cublas_api;
10 import cuda_d.cuComplex;
11 
12 extern (C):
13 
14 struct cublasXtContext;
15 alias cublasXtHandle_t = cublasXtContext*;
16 
17 cublasStatus_t cublasXtCreate (cublasXtHandle_t* handle);
18 cublasStatus_t cublasXtDestroy (cublasXtHandle_t handle);
19 cublasStatus_t cublasXtGetNumBoards (int nbDevices, int* deviceId, int* nbBoards);
20 cublasStatus_t cublasXtMaxBoards (int* nbGpuBoards);
21 /* This routine selects the Gpus that the user want to use for CUBLAS-XT */
22 cublasStatus_t cublasXtDeviceSelect (cublasXtHandle_t handle, int nbDevices, int* deviceId);
23 
24 /* This routine allows to change the dimension of the tiles ( blockDim x blockDim ) */
25 cublasStatus_t cublasXtSetBlockDim (cublasXtHandle_t handle, int blockDim);
26 cublasStatus_t cublasXtGetBlockDim (cublasXtHandle_t handle, int* blockDim);
27 
28 enum cublasXtPinnedMemMode_t
29 {
30     CUBLASXT_PINNING_DISABLED = 0,
31     CUBLASXT_PINNING_ENABLED = 1
32 }
33 
34 cublasStatus_t cublasXtGetPinningMemMode (cublasXtHandle_t handle, cublasXtPinnedMemMode_t* mode);
35 cublasStatus_t cublasXtSetPinningMemMode (cublasXtHandle_t handle, cublasXtPinnedMemMode_t mode);
36 
37 enum cublasXtOpType_t
38 {
39     CUBLASXT_FLOAT = 0,
40     CUBLASXT_DOUBLE = 1,
41     CUBLASXT_COMPLEX = 2,
42     CUBLASXT_DOUBLECOMPLEX = 3
43 }
44 
45 enum cublasXtBlasOp_t
46 {
47     CUBLASXT_GEMM = 0,
48     CUBLASXT_SYRK = 1,
49     CUBLASXT_HERK = 2,
50     CUBLASXT_SYMM = 3,
51     CUBLASXT_HEMM = 4,
52     CUBLASXT_TRSM = 5,
53     CUBLASXT_SYR2K = 6,
54     CUBLASXT_HER2K = 7,
55 
56     CUBLASXT_SPMM = 8,
57     CUBLASXT_SYRKX = 9,
58     CUBLASXT_HERKX = 10,
59     CUBLASXT_TRMM = 11,
60     CUBLASXT_ROUTINE_MAX = 12
61 }
62 
63 /* Currently only 32-bit integer BLAS routines are supported */
64 cublasStatus_t cublasXtSetCpuRoutine (cublasXtHandle_t handle, cublasXtBlasOp_t blasOp, cublasXtOpType_t type, void* blasFunctor);
65 
66 /* Specified the percentage of work that should done by the CPU, default is 0 (no work) */
67 cublasStatus_t cublasXtSetCpuRatio (cublasXtHandle_t handle, cublasXtBlasOp_t blasOp, cublasXtOpType_t type, float ratio);
68 
69 /* GEMM */
70 cublasStatus_t cublasXtSgemm (
71     cublasXtHandle_t handle,
72     cublasOperation_t transa,
73     cublasOperation_t transb,
74     size_t m,
75     size_t n,
76     size_t k,
77     const(float)* alpha,
78     const(float)* A,
79     size_t lda,
80     const(float)* B,
81     size_t ldb,
82     const(float)* beta,
83     float* C,
84     size_t ldc);
85 
86 cublasStatus_t cublasXtDgemm (
87     cublasXtHandle_t handle,
88     cublasOperation_t transa,
89     cublasOperation_t transb,
90     size_t m,
91     size_t n,
92     size_t k,
93     const(double)* alpha,
94     const(double)* A,
95     size_t lda,
96     const(double)* B,
97     size_t ldb,
98     const(double)* beta,
99     double* C,
100     size_t ldc);
101 
102 cublasStatus_t cublasXtCgemm (
103     cublasXtHandle_t handle,
104     cublasOperation_t transa,
105     cublasOperation_t transb,
106     size_t m,
107     size_t n,
108     size_t k,
109     const(cuComplex)* alpha,
110     const(cuComplex)* A,
111     size_t lda,
112     const(cuComplex)* B,
113     size_t ldb,
114     const(cuComplex)* beta,
115     cuComplex* C,
116     size_t ldc);
117 
118 cublasStatus_t cublasXtZgemm (
119     cublasXtHandle_t handle,
120     cublasOperation_t transa,
121     cublasOperation_t transb,
122     size_t m,
123     size_t n,
124     size_t k,
125     const(cuDoubleComplex)* alpha,
126     const(cuDoubleComplex)* A,
127     size_t lda,
128     const(cuDoubleComplex)* B,
129     size_t ldb,
130     const(cuDoubleComplex)* beta,
131     cuDoubleComplex* C,
132     size_t ldc);
133 
134 /* ------------------------------------------------------- */
135 /* SYRK */
136 cublasStatus_t cublasXtSsyrk (
137     cublasXtHandle_t handle,
138     cublasFillMode_t uplo,
139     cublasOperation_t trans,
140     size_t n,
141     size_t k,
142     const(float)* alpha,
143     const(float)* A,
144     size_t lda,
145     const(float)* beta,
146     float* C,
147     size_t ldc);
148 
149 cublasStatus_t cublasXtDsyrk (
150     cublasXtHandle_t handle,
151     cublasFillMode_t uplo,
152     cublasOperation_t trans,
153     size_t n,
154     size_t k,
155     const(double)* alpha,
156     const(double)* A,
157     size_t lda,
158     const(double)* beta,
159     double* C,
160     size_t ldc);
161 
162 cublasStatus_t cublasXtCsyrk (
163     cublasXtHandle_t handle,
164     cublasFillMode_t uplo,
165     cublasOperation_t trans,
166     size_t n,
167     size_t k,
168     const(cuComplex)* alpha,
169     const(cuComplex)* A,
170     size_t lda,
171     const(cuComplex)* beta,
172     cuComplex* C,
173     size_t ldc);
174 
175 cublasStatus_t cublasXtZsyrk (
176     cublasXtHandle_t handle,
177     cublasFillMode_t uplo,
178     cublasOperation_t trans,
179     size_t n,
180     size_t k,
181     const(cuDoubleComplex)* alpha,
182     const(cuDoubleComplex)* A,
183     size_t lda,
184     const(cuDoubleComplex)* beta,
185     cuDoubleComplex* C,
186     size_t ldc);
187 
188 /* -------------------------------------------------------------------- */
189 /* HERK */
190 cublasStatus_t cublasXtCherk (
191     cublasXtHandle_t handle,
192     cublasFillMode_t uplo,
193     cublasOperation_t trans,
194     size_t n,
195     size_t k,
196     const(float)* alpha,
197     const(cuComplex)* A,
198     size_t lda,
199     const(float)* beta,
200     cuComplex* C,
201     size_t ldc);
202 
203 cublasStatus_t cublasXtZherk (
204     cublasXtHandle_t handle,
205     cublasFillMode_t uplo,
206     cublasOperation_t trans,
207     size_t n,
208     size_t k,
209     const(double)* alpha,
210     const(cuDoubleComplex)* A,
211     size_t lda,
212     const(double)* beta,
213     cuDoubleComplex* C,
214     size_t ldc);
215 
216 /* -------------------------------------------------------------------- */
217 /* SYR2K */
218 cublasStatus_t cublasXtSsyr2k (
219     cublasXtHandle_t handle,
220     cublasFillMode_t uplo,
221     cublasOperation_t trans,
222     size_t n,
223     size_t k,
224     const(float)* alpha,
225     const(float)* A,
226     size_t lda,
227     const(float)* B,
228     size_t ldb,
229     const(float)* beta,
230     float* C,
231     size_t ldc);
232 
233 cublasStatus_t cublasXtDsyr2k (
234     cublasXtHandle_t handle,
235     cublasFillMode_t uplo,
236     cublasOperation_t trans,
237     size_t n,
238     size_t k,
239     const(double)* alpha,
240     const(double)* A,
241     size_t lda,
242     const(double)* B,
243     size_t ldb,
244     const(double)* beta,
245     double* C,
246     size_t ldc);
247 
248 cublasStatus_t cublasXtCsyr2k (
249     cublasXtHandle_t handle,
250     cublasFillMode_t uplo,
251     cublasOperation_t trans,
252     size_t n,
253     size_t k,
254     const(cuComplex)* alpha,
255     const(cuComplex)* A,
256     size_t lda,
257     const(cuComplex)* B,
258     size_t ldb,
259     const(cuComplex)* beta,
260     cuComplex* C,
261     size_t ldc);
262 
263 cublasStatus_t cublasXtZsyr2k (
264     cublasXtHandle_t handle,
265     cublasFillMode_t uplo,
266     cublasOperation_t trans,
267     size_t n,
268     size_t k,
269     const(cuDoubleComplex)* alpha,
270     const(cuDoubleComplex)* A,
271     size_t lda,
272     const(cuDoubleComplex)* B,
273     size_t ldb,
274     const(cuDoubleComplex)* beta,
275     cuDoubleComplex* C,
276     size_t ldc);
277 
278 /* -------------------------------------------------------------------- */
279 /* HERKX : variant extension of HERK */
280 cublasStatus_t cublasXtCherkx (
281     cublasXtHandle_t handle,
282     cublasFillMode_t uplo,
283     cublasOperation_t trans,
284     size_t n,
285     size_t k,
286     const(cuComplex)* alpha,
287     const(cuComplex)* A,
288     size_t lda,
289     const(cuComplex)* B,
290     size_t ldb,
291     const(float)* beta,
292     cuComplex* C,
293     size_t ldc);
294 
295 cublasStatus_t cublasXtZherkx (
296     cublasXtHandle_t handle,
297     cublasFillMode_t uplo,
298     cublasOperation_t trans,
299     size_t n,
300     size_t k,
301     const(cuDoubleComplex)* alpha,
302     const(cuDoubleComplex)* A,
303     size_t lda,
304     const(cuDoubleComplex)* B,
305     size_t ldb,
306     const(double)* beta,
307     cuDoubleComplex* C,
308     size_t ldc);
309 
310 /* -------------------------------------------------------------------- */
311 /* TRSM */
312 cublasStatus_t cublasXtStrsm (
313     cublasXtHandle_t handle,
314     cublasSideMode_t side,
315     cublasFillMode_t uplo,
316     cublasOperation_t trans,
317     cublasDiagType_t diag,
318     size_t m,
319     size_t n,
320     const(float)* alpha,
321     const(float)* A,
322     size_t lda,
323     float* B,
324     size_t ldb);
325 
326 cublasStatus_t cublasXtDtrsm (
327     cublasXtHandle_t handle,
328     cublasSideMode_t side,
329     cublasFillMode_t uplo,
330     cublasOperation_t trans,
331     cublasDiagType_t diag,
332     size_t m,
333     size_t n,
334     const(double)* alpha,
335     const(double)* A,
336     size_t lda,
337     double* B,
338     size_t ldb);
339 
340 cublasStatus_t cublasXtCtrsm (
341     cublasXtHandle_t handle,
342     cublasSideMode_t side,
343     cublasFillMode_t uplo,
344     cublasOperation_t trans,
345     cublasDiagType_t diag,
346     size_t m,
347     size_t n,
348     const(cuComplex)* alpha,
349     const(cuComplex)* A,
350     size_t lda,
351     cuComplex* B,
352     size_t ldb);
353 
354 cublasStatus_t cublasXtZtrsm (
355     cublasXtHandle_t handle,
356     cublasSideMode_t side,
357     cublasFillMode_t uplo,
358     cublasOperation_t trans,
359     cublasDiagType_t diag,
360     size_t m,
361     size_t n,
362     const(cuDoubleComplex)* alpha,
363     const(cuDoubleComplex)* A,
364     size_t lda,
365     cuDoubleComplex* B,
366     size_t ldb);
367 
368 /* -------------------------------------------------------------------- */
369 /* SYMM : Symmetric Multiply Matrix*/
370 cublasStatus_t cublasXtSsymm (
371     cublasXtHandle_t handle,
372     cublasSideMode_t side,
373     cublasFillMode_t uplo,
374     size_t m,
375     size_t n,
376     const(float)* alpha,
377     const(float)* A,
378     size_t lda,
379     const(float)* B,
380     size_t ldb,
381     const(float)* beta,
382     float* C,
383     size_t ldc);
384 
385 cublasStatus_t cublasXtDsymm (
386     cublasXtHandle_t handle,
387     cublasSideMode_t side,
388     cublasFillMode_t uplo,
389     size_t m,
390     size_t n,
391     const(double)* alpha,
392     const(double)* A,
393     size_t lda,
394     const(double)* B,
395     size_t ldb,
396     const(double)* beta,
397     double* C,
398     size_t ldc);
399 
400 cublasStatus_t cublasXtCsymm (
401     cublasXtHandle_t handle,
402     cublasSideMode_t side,
403     cublasFillMode_t uplo,
404     size_t m,
405     size_t n,
406     const(cuComplex)* alpha,
407     const(cuComplex)* A,
408     size_t lda,
409     const(cuComplex)* B,
410     size_t ldb,
411     const(cuComplex)* beta,
412     cuComplex* C,
413     size_t ldc);
414 
415 cublasStatus_t cublasXtZsymm (
416     cublasXtHandle_t handle,
417     cublasSideMode_t side,
418     cublasFillMode_t uplo,
419     size_t m,
420     size_t n,
421     const(cuDoubleComplex)* alpha,
422     const(cuDoubleComplex)* A,
423     size_t lda,
424     const(cuDoubleComplex)* B,
425     size_t ldb,
426     const(cuDoubleComplex)* beta,
427     cuDoubleComplex* C,
428     size_t ldc);
429 
430 /* -------------------------------------------------------------------- */
431 /* HEMM : Hermitian Matrix Multiply */
432 cublasStatus_t cublasXtChemm (
433     cublasXtHandle_t handle,
434     cublasSideMode_t side,
435     cublasFillMode_t uplo,
436     size_t m,
437     size_t n,
438     const(cuComplex)* alpha,
439     const(cuComplex)* A,
440     size_t lda,
441     const(cuComplex)* B,
442     size_t ldb,
443     const(cuComplex)* beta,
444     cuComplex* C,
445     size_t ldc);
446 
447 cublasStatus_t cublasXtZhemm (
448     cublasXtHandle_t handle,
449     cublasSideMode_t side,
450     cublasFillMode_t uplo,
451     size_t m,
452     size_t n,
453     const(cuDoubleComplex)* alpha,
454     const(cuDoubleComplex)* A,
455     size_t lda,
456     const(cuDoubleComplex)* B,
457     size_t ldb,
458     const(cuDoubleComplex)* beta,
459     cuDoubleComplex* C,
460     size_t ldc);
461 
462 /* -------------------------------------------------------------------- */
463 /* SYRKX : variant extension of SYRK  */
464 cublasStatus_t cublasXtSsyrkx (
465     cublasXtHandle_t handle,
466     cublasFillMode_t uplo,
467     cublasOperation_t trans,
468     size_t n,
469     size_t k,
470     const(float)* alpha,
471     const(float)* A,
472     size_t lda,
473     const(float)* B,
474     size_t ldb,
475     const(float)* beta,
476     float* C,
477     size_t ldc);
478 
479 cublasStatus_t cublasXtDsyrkx (
480     cublasXtHandle_t handle,
481     cublasFillMode_t uplo,
482     cublasOperation_t trans,
483     size_t n,
484     size_t k,
485     const(double)* alpha,
486     const(double)* A,
487     size_t lda,
488     const(double)* B,
489     size_t ldb,
490     const(double)* beta,
491     double* C,
492     size_t ldc);
493 
494 cublasStatus_t cublasXtCsyrkx (
495     cublasXtHandle_t handle,
496     cublasFillMode_t uplo,
497     cublasOperation_t trans,
498     size_t n,
499     size_t k,
500     const(cuComplex)* alpha,
501     const(cuComplex)* A,
502     size_t lda,
503     const(cuComplex)* B,
504     size_t ldb,
505     const(cuComplex)* beta,
506     cuComplex* C,
507     size_t ldc);
508 
509 cublasStatus_t cublasXtZsyrkx (
510     cublasXtHandle_t handle,
511     cublasFillMode_t uplo,
512     cublasOperation_t trans,
513     size_t n,
514     size_t k,
515     const(cuDoubleComplex)* alpha,
516     const(cuDoubleComplex)* A,
517     size_t lda,
518     const(cuDoubleComplex)* B,
519     size_t ldb,
520     const(cuDoubleComplex)* beta,
521     cuDoubleComplex* C,
522     size_t ldc);
523 
524 /* -------------------------------------------------------------------- */
525 /* HER2K : variant extension of HERK  */
526 cublasStatus_t cublasXtCher2k (
527     cublasXtHandle_t handle,
528     cublasFillMode_t uplo,
529     cublasOperation_t trans,
530     size_t n,
531     size_t k,
532     const(cuComplex)* alpha,
533     const(cuComplex)* A,
534     size_t lda,
535     const(cuComplex)* B,
536     size_t ldb,
537     const(float)* beta,
538     cuComplex* C,
539     size_t ldc);
540 
541 cublasStatus_t cublasXtZher2k (
542     cublasXtHandle_t handle,
543     cublasFillMode_t uplo,
544     cublasOperation_t trans,
545     size_t n,
546     size_t k,
547     const(cuDoubleComplex)* alpha,
548     const(cuDoubleComplex)* A,
549     size_t lda,
550     const(cuDoubleComplex)* B,
551     size_t ldb,
552     const(double)* beta,
553     cuDoubleComplex* C,
554     size_t ldc);
555 
556 /* -------------------------------------------------------------------- */
557 /* SPMM : Symmetric Packed Multiply Matrix*/
558 cublasStatus_t cublasXtSspmm (
559     cublasXtHandle_t handle,
560     cublasSideMode_t side,
561     cublasFillMode_t uplo,
562     size_t m,
563     size_t n,
564     const(float)* alpha,
565     const(float)* AP,
566     const(float)* B,
567     size_t ldb,
568     const(float)* beta,
569     float* C,
570     size_t ldc);
571 
572 cublasStatus_t cublasXtDspmm (
573     cublasXtHandle_t handle,
574     cublasSideMode_t side,
575     cublasFillMode_t uplo,
576     size_t m,
577     size_t n,
578     const(double)* alpha,
579     const(double)* AP,
580     const(double)* B,
581     size_t ldb,
582     const(double)* beta,
583     double* C,
584     size_t ldc);
585 
586 cublasStatus_t cublasXtCspmm (
587     cublasXtHandle_t handle,
588     cublasSideMode_t side,
589     cublasFillMode_t uplo,
590     size_t m,
591     size_t n,
592     const(cuComplex)* alpha,
593     const(cuComplex)* AP,
594     const(cuComplex)* B,
595     size_t ldb,
596     const(cuComplex)* beta,
597     cuComplex* C,
598     size_t ldc);
599 
600 cublasStatus_t cublasXtZspmm (
601     cublasXtHandle_t handle,
602     cublasSideMode_t side,
603     cublasFillMode_t uplo,
604     size_t m,
605     size_t n,
606     const(cuDoubleComplex)* alpha,
607     const(cuDoubleComplex)* AP,
608     const(cuDoubleComplex)* B,
609     size_t ldb,
610     const(cuDoubleComplex)* beta,
611     cuDoubleComplex* C,
612     size_t ldc);
613 
614 /* -------------------------------------------------------------------- */
615 /* TRMM */
616 cublasStatus_t cublasXtStrmm (
617     cublasXtHandle_t handle,
618     cublasSideMode_t side,
619     cublasFillMode_t uplo,
620     cublasOperation_t trans,
621     cublasDiagType_t diag,
622     size_t m,
623     size_t n,
624     const(float)* alpha,
625     const(float)* A,
626     size_t lda,
627     const(float)* B,
628     size_t ldb,
629     float* C,
630     size_t ldc);
631 
632 cublasStatus_t cublasXtDtrmm (
633     cublasXtHandle_t handle,
634     cublasSideMode_t side,
635     cublasFillMode_t uplo,
636     cublasOperation_t trans,
637     cublasDiagType_t diag,
638     size_t m,
639     size_t n,
640     const(double)* alpha,
641     const(double)* A,
642     size_t lda,
643     const(double)* B,
644     size_t ldb,
645     double* C,
646     size_t ldc);
647 
648 cublasStatus_t cublasXtCtrmm (
649     cublasXtHandle_t handle,
650     cublasSideMode_t side,
651     cublasFillMode_t uplo,
652     cublasOperation_t trans,
653     cublasDiagType_t diag,
654     size_t m,
655     size_t n,
656     const(cuComplex)* alpha,
657     const(cuComplex)* A,
658     size_t lda,
659     const(cuComplex)* B,
660     size_t ldb,
661     cuComplex* C,
662     size_t ldc);
663 
664 cublasStatus_t cublasXtZtrmm (
665     cublasXtHandle_t handle,
666     cublasSideMode_t side,
667     cublasFillMode_t uplo,
668     cublasOperation_t trans,
669     cublasDiagType_t diag,
670     size_t m,
671     size_t n,
672     const(cuDoubleComplex)* alpha,
673     const(cuDoubleComplex)* A,
674     size_t lda,
675     const(cuDoubleComplex)* B,
676     size_t ldb,
677     cuDoubleComplex* C,
678     size_t ldc);